{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "# default_exp problem_types.cls\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\""
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "# hide\n",
    "# test setup\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from m3tl.test_base import TestBase\n",
    "from m3tl.input_fn import train_eval_input_fn\n",
    "from m3tl.test_base import test_top_layer\n",
    "test_base = TestBase()\n",
    "params = test_base.params\n",
    "\n",
    "hidden_dim = params.bert_config.hidden_size\n",
    "\n",
    "train_dataset = train_eval_input_fn(params=params)\n",
    "one_batch = next(train_dataset.as_numpy_iterator())\n"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 20:21:12.618 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_fake_ner, problem type: seq_tag\n",
      "2021-06-24 20:21:12.618 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_fake_multi_cls, problem type: multi_cls\n",
      "2021-06-24 20:21:12.619 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_fake_cls, problem type: cls\n",
      "2021-06-24 20:21:12.619 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_masklm, problem type: masklm\n",
      "2021-06-24 20:21:12.620 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_fake_regression, problem type: regression\n",
      "2021-06-24 20:21:12.620 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_fake_vector_fit, problem type: vector_fit\n",
      "2021-06-24 20:21:12.621 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem weibo_premask_mlm, problem type: premask_mlm\n",
      "2021-06-24 20:21:12.622 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem fake_contrastive_learning, problem type: contrastive_learning\n",
      "2021-06-24 20:21:12.622 | WARNING  | m3tl.base_params:assign_problem:647 - base_dir and dir_name arguments will be deprecated in the future. Please use model_dir instead.\n",
      "2021-06-24 20:21:12.624 | WARNING  | m3tl.base_params:prepare_dir:363 - bert_config not exists. will load model from huggingface checkpoint.\n",
      "2021-06-24 20:21:18.310 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train\n",
      "2021-06-24 20:21:18.372 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:258 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.\n",
      "2021-06-24 20:21:18.388 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing /tmp/tmpzuvcbd08/weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit/train_00000.tfrecord\n",
      "2021-06-24 20:21:18.450 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:258 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.\n",
      "2021-06-24 20:21:18.465 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing /tmp/tmpzuvcbd08/weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit/eval_00000.tfrecord\n",
      "2021-06-24 20:21:18.492 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing /tmp/tmpzuvcbd08/weibo_fake_multi_cls/train_00000.tfrecord\n",
      "2021-06-24 20:21:18.516 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing /tmp/tmpzuvcbd08/weibo_fake_multi_cls/eval_00000.tfrecord\n",
      "2021-06-24 20:21:18.595 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing /tmp/tmpzuvcbd08/weibo_masklm/train_00000.tfrecord\n",
      "2021-06-24 20:21:18.642 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing /tmp/tmpzuvcbd08/weibo_masklm/eval_00000.tfrecord\n",
      "2021-06-24 20:21:18.705 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing /tmp/tmpzuvcbd08/weibo_premask_mlm/train_00000.tfrecord\n",
      "2021-06-24 20:21:18.769 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing /tmp/tmpzuvcbd08/weibo_premask_mlm/eval_00000.tfrecord\n",
      "2021-06-24 20:21:18.787 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing /tmp/tmpzuvcbd08/fake_contrastive_learning/train_00000.tfrecord\n",
      "2021-06-24 20:21:18.805 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing /tmp/tmpzuvcbd08/fake_contrastive_learning/eval_00000.tfrecord\n",
      "2021-06-24 20:21:20.227 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: \n",
      "2021-06-24 20:21:20.228 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {\n",
      "    \"weibo_fake_cls_weibo_fake_ner_weibo_fake_regression_weibo_fake_vector_fit\": 0.20833333333333334,\n",
      "    \"weibo_fake_multi_cls\": 0.20833333333333334,\n",
      "    \"weibo_masklm\": 0.16666666666666666,\n",
      "    \"weibo_premask_mlm\": 0.20833333333333334,\n",
      "    \"fake_contrastive_learning\": 0.20833333333333334\n",
      "}\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Classification(cls)"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Classification. By default this problem will use `[CLS]` token embedding.\n",
    "\n",
    "Example: `m3tl.predefined_problems.get_weibo_fake_cls_fn`."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Imports and utils\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "# export\n",
    "from functools import partial\n",
    "from typing import List\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from m3tl.base_params import BaseParams\n",
    "from m3tl.problem_types.utils import (empty_tensor_handling_loss,\n",
    "                                      nan_loss_handling)\n",
    "from m3tl.special_tokens import PREDICT, TRAIN\n",
    "from m3tl.utils import (LabelEncoder, get_label_encoder_save_path, get_phase,\n",
    "                        need_make_label_encoder, variable_summaries)\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Top Layer"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "# export\n",
    "\n",
    "class Classification(tf.keras.layers.Layer):\n",
    "    \"\"\"Classification Top Layer\"\"\"\n",
    "    def __init__(self, params: BaseParams, problem_name: str) -> None:\n",
    "        super(Classification, self).__init__(name=problem_name)\n",
    "        self.params = params\n",
    "        self.problem_name = problem_name\n",
    "        self.num_classes = self.params.get_problem_info(problem=problem_name, info_name='num_classes')\n",
    "        self.dense = tf.keras.layers.Dense(self.num_classes, activation=None)\n",
    "        self.metric_fn = tf.keras.metrics.SparseCategoricalAccuracy(\n",
    "            name='{}_acc'.format(self.problem_name))\n",
    "\n",
    "        self.dropout = tf.keras.layers.Dropout(1-params.dropout_keep_prob)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        mode = get_phase()\n",
    "        training = (mode == TRAIN)\n",
    "        feature, hidden_feature = inputs\n",
    "        hidden_feature = hidden_feature['pooled']\n",
    "        if mode != PREDICT:\n",
    "            labels = feature['{}_label_ids'.format(self.problem_name)]\n",
    "        else:\n",
    "            labels = None\n",
    "        hidden_feature = self.dropout(hidden_feature, training)\n",
    "        logits = self.dense(hidden_feature)\n",
    "\n",
    "        if self.params.detail_log:\n",
    "            for weigth_variable in self.weights:\n",
    "                variable_summaries(weigth_variable, self.problem_name)\n",
    "\n",
    "        if mode != PREDICT:\n",
    "            # labels = tf.squeeze(labels)\n",
    "            # convert labels to one-hot to use label_smoothing\n",
    "            one_hot_labels = tf.one_hot(\n",
    "                labels, depth=self.num_classes)\n",
    "            loss_fn = partial(tf.keras.losses.categorical_crossentropy,\n",
    "                              from_logits=True, label_smoothing=self.params.label_smoothing)\n",
    "\n",
    "            loss = empty_tensor_handling_loss(\n",
    "                one_hot_labels, logits,\n",
    "                loss_fn)\n",
    "            loss = nan_loss_handling(loss)\n",
    "            self.add_loss(loss)\n",
    "            acc = self.metric_fn(labels, logits)\n",
    "            self.add_metric(acc)\n",
    "        return tf.nn.softmax(\n",
    "            logits, name='%s_predict' % self.problem_name)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "# hide\n",
    "from m3tl.test_base import test_top_layer\n",
    "test_top_layer(Classification, problem='weibo_fake_cls', params=params, sample_features=one_batch, hidden_dim=hidden_dim)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 20:21:23.812 | DEBUG    | m3tl.test_base:test_top_layer:248 - Testing Classification\n",
      "2021-06-24 20:21:23.825 | DEBUG    | m3tl.test_base:test_top_layer:254 - testing batch size 0\n",
      "2021-06-24 20:21:23.825 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train\n",
      "2021-06-24 20:21:23.880 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval\n",
      "2021-06-24 20:21:23.886 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer\n",
      "2021-06-24 20:21:23.891 | DEBUG    | m3tl.test_base:test_top_layer:254 - testing batch size 1\n",
      "2021-06-24 20:21:23.892 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train\n",
      "2021-06-24 20:21:23.913 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval\n",
      "2021-06-24 20:21:23.920 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer\n",
      "2021-06-24 20:21:23.925 | DEBUG    | m3tl.test_base:test_top_layer:254 - testing batch size 2\n",
      "2021-06-24 20:21:23.925 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train\n",
      "2021-06-24 20:21:23.985 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval\n",
      "2021-06-24 20:21:23.991 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Get or make label encoder function\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "# export\n",
    "def cls_get_or_make_label_encoder_fn(params: BaseParams, problem: str, mode: str, label_list: List[str], *args, **kwargs) -> LabelEncoder:\n",
    "\n",
    "    le_path = get_label_encoder_save_path(params=params, problem=problem)\n",
    "    label_encoder = LabelEncoder()\n",
    "    if need_make_label_encoder(mode=mode, le_path=le_path, overwrite=kwargs['overwrite']):\n",
    "        # fit and save label encoder\n",
    "        label_encoder.fit(label_list)\n",
    "        label_encoder.dump(le_path)\n",
    "        params.set_problem_info(problem=problem, info_name='num_classes', info=len(label_encoder.encode_dict))\n",
    "    else:\n",
    "        label_encoder.load(le_path)\n",
    "\n",
    "    return label_encoder"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Label handing function"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "source": [
    "# export\n",
    "def cls_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs):\n",
    "    label_id = label_encoder.transform([target]).tolist()[0]\n",
    "    label_id = np.int32(label_id)\n",
    "    return label_id, None\n",
    "\n"
   ],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}